import json
import torch
import random
from tqdm import tqdm
# import openai
import os
# from eval_datasets import VQADataset, VQARADDataset, PMCVQADataset, WEBQADataset
# from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from azfuse import File
import time
import openai

# Credentials file layout: line 1 is the API key, line 2 is the API base URL.
with open('credentials/openai_key.txt') as f:
    # Read all lines exactly once. A second f.readlines() on the same handle
    # starts at EOF and returns [], so indexing [1] on it always failed.
    credential_lines = f.readlines()
key = credential_lines[0].strip()
api_base = credential_lines[1].strip()

openai.api_key = key
openai.api_base = api_base
# openai.api_version = '2022-12-01' # this may change in the future
openai.api_version = "2023-07-01-preview" # this may change in the future
# deployment_id='gpt-4-32k-0314' #This will correspond to the custom name you chose for your deployment when you deployed a model.

# Prefer the GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"

def stringify(data, kshot):
    """Render ``kshot`` randomly sampled demonstrations as one prompt string.

    Args:
        data: Sequence of demonstration dicts with the keys ``"Question"``,
            ``"Reference answers"`` (an iterable of strings),
            ``"Candidate answer"`` and ``"Output"``.
        kshot: Number of demonstrations to draw with ``random.sample``.

    Returns:
        The sampled demonstrations, each formatted as four labelled lines and
        followed by a blank line; an empty string when ``kshot`` is 0.
    """
    rendered = []
    for demo in random.sample(data, kshot):
        question = demo["Question"]
        references = ", ".join(demo["Reference answers"])
        candidate = demo["Candidate answer"]
        output = demo["Output"]
        lines = (
            f"Question: {question}",
            f"Reference answers: {references}",
            f"Candidate answer: {candidate}",
            f"Output: {output}",
        )
        # One newline closes the last line, the second separates demos.
        rendered.append("\n".join(lines) + "\n\n")
    return "".join(rendered)


def load_demonstrations(args):
    """Load the judge instruction and the sampled few-shot demonstrations.

    Args:
        args: Object exposing ``demos_dir`` (directory containing
            ``demos_binary.json`` and ``demos_nbinary.json``) and ``kshot``
            (number of demonstrations to sample from each file).

    Returns:
        A tuple ``(instruction_string, binary_demo_string,
        nbinary_demo_string)``.
    """
    instruction_string = "You are given a question, a gold-standard reference answers written by experts, and a candidate answer. Please rate the accuracy of the candidate answer for the question considering the reference answer. Use a scale of 1-3, with 1 indicating an incorrect or irrelevant answer, 2 indicating an ambiguous or incomplete answer, and 3 indicating a correct answer. Give the rationale before rating."

    # load demos; context managers make sure the file handles are closed
    with open(os.path.join(args.demos_dir, "demos_binary.json"), "r") as f:
        binary_data = json.load(f)["demos_binary"]
    with open(os.path.join(args.demos_dir, "demos_nbinary.json"), "r") as f:
        nbinary_data = json.load(f)["demos_nbinary"]
    binary_demo_string = stringify(binary_data, args.kshot)
    nbinary_demo_string = stringify(nbinary_data, args.kshot)
    return instruction_string, binary_demo_string, nbinary_demo_string



def run_lave_metric_acc(gt_data, pred_data, deployment_id="gpt4", debug=False, overwrite=False, max_num_retries=10):
    """Grade each prediction against its reference answer with a GPT judge.

    For every prediction the judge is asked for a 1-3 rating, which is mapped
    to an accuracy of 0.0 / 0.5 / 1.0 (or -1 when no rating could be
    obtained). Raw judge replies are cached per question in
    ``<pred dir>/<deployment_id>_lave_acc/<qid>.txt``.

    Args:
        gt_data: Path to the ground-truth file. A ``.jsonl`` path is read as
            one record per line keyed by ``question_id``; any other path is
            loaded as a JSON list of records with ``id``, ``conversations``
            and ``image`` fields.
        pred_data: Path to the predictions ``.jsonl`` file; each record needs
            ``question_id``, ``prompt`` and ``text``.
        deployment_id: Value passed as ``engine`` to
            ``openai.ChatCompletion.create``; also prefixes the output folder.
        debug: When true, write to ``*.debug.jsonl`` files and stop once
            ``idx > 10``.
        overwrite: When true, ignore existing outputs and query the judge
            again.
        max_num_retries: Maximum number of API attempts per question.

    Returns:
        Path of the per-question output ``.jsonl`` file.

    Raises:
        AssertionError: If a predicted question id is missing from the ground
            truth or its prompt does not match the ground-truth question.
    """
    instruction_string = '''You are given a question, a gold-standard reference answers written by experts, and a candidate answer. Please rate the accuracy of the candidate answer for the question considering the reference answer. Use a scale of 1-3, with 1 indicating an incorrect or irrelevant answer, 2 indicating an ambiguous or incomplete answer, and 3 indicating a correct answer. Give the rationale after rating.
    
    Please follow the following format:
    Rating: 1
    Rationale: The candidate answer is incorrect because ...
    '''
    output_folder = os.path.join(os.path.dirname(pred_data), deployment_id+"_lave_acc")
    if debug:
        output_file = os.path.join(output_folder, "lave_output.debug.jsonl")
        result_file = os.path.join(output_folder, "lave_result.debug.jsonl")
    else:
        output_file = os.path.join(output_folder, "lave_output.jsonl")
        result_file = os.path.join(output_folder, "lave_result.jsonl")
    print(f"Output file: {output_file}")
    if File.isfile(output_file) and (not overwrite):
        print(f"Output file {output_file} already exists, skipping...")
        get_acc_metrics(output_file, result_file)
        return output_file

    # load data
    if gt_data.endswith(".jsonl"):
        with File.open(gt_data, 'r') as f:
            gt_data = [json.loads(el) for el in f]
        qid2gt_ans = {str(d["question_id"]): d for d in gt_data}
    else:
        with File.open(gt_data, 'r') as f:
            gt_data = json.load(f)
        qid2gt_ans = {str(d["id"]): {"answer": d["conversations"][-1]["value"], "text": d["conversations"][0]["value"], "image": d["image"]} for d in gt_data}

    with File.open(pred_data, 'r') as f:
        pred_data = [json.loads(el) for el in f]
    qid2pred_ans = {str(d["question_id"]): d for d in pred_data}
    qids = [str(d["question_id"]) for d in pred_data]
    results = []

    # Judge rating -> accuracy credit.
    rating_to_acc = {1: 0.0, 2: 0.5, 3: 1.0}

    for idx, qid in enumerate(tqdm(qids)):
        assert qid in qid2gt_ans, f"Question id {qid} not found in ground truth data"
        assert qid2gt_ans[qid]["text"].replace("<image>\n", "") == qid2pred_ans[qid]["prompt"], f"Prompt mismatch for question id {qid}, {qid2gt_ans[qid]['text']} vs {qid2pred_ans[qid]['prompt']}"
        pred = qid2pred_ans[qid]["text"]
        gt_ans = qid2gt_ans[qid]["answer"]
        # get details about gt annotation other than "answer", "text", "image", "question_id"
        gt_ann = {k: v for k, v in qid2gt_ans[qid].items() if k not in ["answer", "text", "image", "question_id"]}
        question = qid2gt_ans[qid]["text"]
        eval_string = f"Question: {question}\nReference answer: {gt_ans}\nCandidate answer: {pred}\nOutput: "
        messages = [
            {"role": "user", "content": instruction_string+"\n"+eval_string},
        ]

        # Reset for every question so that a skipped or failed call can never
        # reuse the previous question's reply (or hit an unbound name).
        content = ""
        qid_output_file = os.path.join(output_folder, f"{qid}.txt")
        if not os.path.exists(qid_output_file) or overwrite:
            # Every attempt counts, so max_num_retries really bounds the loop.
            for _ in range(max_num_retries):
                try:
                    response = openai.ChatCompletion.create(
                        engine=deployment_id,
                        messages=messages,
                        temperature=1,
                        max_tokens=1024,
                    )
                    reply = response['choices'][0]['message']['content']
                    content = reply.split("Output:")[-1]
                except Exception as e:
                    if "content management policy" in f"{e}":
                        # Retrying a filtered prompt cannot succeed.
                        print("Skipping due to content management policy")
                        break
                    print(f"Failed to call GPT-4 ({e}), sleep 2s")
                    time.sleep(2)
                    continue
                with File.open(qid_output_file, "w") as f:
                    f.write(content)
                break
        else:
            # Reuse the cached judge reply; read() returns a str (readlines()
            # returns a list, which has no .strip()).
            with File.open(qid_output_file, "r") as f:
                content = f.read().strip()

        # parse for ratings: the digit right after the last "Rating: "
        try:
            rating = int(content.split("Rating: ")[-1][0])
        except (ValueError, IndexError):
            print(f"Error parsing rating for question {qid}")
            print(f"score_thread_last: {content}")
            rating = -1
        acc = rating_to_acc.get(rating, -1)
        if acc == -1:
            print(f"Error parsing responses for question {qid}")
            print(f"score_thread_last: {content}")
        print(f"Question: {question}")
        print(f"Refs: {gt_ans}")
        print(f"Pred: {pred}")
        print(f"Acc: {acc}")
        print("-------")
        to_save = {
            "question_id": qid,
            "question": question,
            "answer": pred,
            "acc": acc,
            "gt": gt_ans,
            "lave_output": content,
        }
        to_save.update(gt_ann)
        results.append(to_save)
        if debug and idx > 10:
            break

    with File.open(output_file, "w") as f:
        for d in results:
            f.write(json.dumps(d) + "\n")

    get_acc_metrics(output_file, result_file)
    print("DONE!!")
    return output_file


def is_question_answerable(data, assume_answerable=True):
    """Decide from a record's annotations whether its question is answerable.

    The first annotation found wins, checked in this order: ``answerable``
    (returned as-is), ``category`` (``"unk"`` means unanswerable),
    ``answer_type`` (``"unanswerable"``), ``question_type``
    (``"adversarial"`` or ``"absurd"``), and finally a ``question_id``
    containing ``"remove_0"``.

    Args:
        data: Mapping holding the record's annotations.
        assume_answerable: Result policy when no annotation applies.

    Returns:
        ``data["answerable"]`` when present; otherwise 0 (unanswerable) or
        1 (answerable). When nothing applies, 1 if ``assume_answerable``
        else ``None``.
    """
    if "answerable" in data:
        return data["answerable"]
    if "category" in data:
        return 0 if data["category"] == "unk" else 1
    if "answer_type" in data:
        return 0 if data["answer_type"] == "unanswerable" else 1
    if "question_type" in data:
        return 0 if data["question_type"] in ("adversarial", "absurd") else 1
    # a hardcode for our unk questions; tolerate a missing or non-string id
    # instead of raising KeyError / TypeError
    if "remove_0" in str(data.get("question_id", "")):
        return 0
    return 1 if assume_answerable else None


def safe_divide(a, b):
    """Return ``a / b`` as a float, or 0 when ``b`` is zero.

    Both arguments are converted with ``float()`` before dividing.
    """
    numerator, denominator = float(a), float(b)
    return numerator / denominator if denominator != 0 else 0

def get_acc_metrics(output_file, result_file):
    """Aggregate per-question LAVE accuracies and write them to ``result_file``.

    Records with ``acc == -1`` are counted as ``missing`` and excluded from
    every average. The remaining records are split into unanswerable
    ("refusal") and answerable ("answer") questions via
    ``is_question_answerable``.

    Args:
        output_file: Per-question ``.jsonl`` file with at least ``acc`` and
            ``answer`` in every record.
        result_file: Destination for the aggregated JSON results.

    Returns:
        None. Only prints a message when ``output_file`` does not exist.
    """
    if not File.isfile(output_file):
        print(f"Output file {output_file} does not exist. Skipping...")
        return

    print("Calculating the results.....")
    with File.open(output_file, 'r') as f:
        outputs = [json.loads(el) for el in f]
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }
    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0
    }
    for d in tqdm(outputs):
        # get the final acc
        acc = d["acc"]
        answerable = is_question_answerable(d, assume_answerable=True)

        if acc == -1:
            total_num_instance["missing"] += 1
            continue

        # Accumulate the raw judge accuracy; a bare `+` here would silently
        # discard the value and leave acc_sum["all"] at 0.
        final_acc["all"] += acc
        total_num_instance["all"] += 1
        if answerable == 0:
            total_num_instance["refusal"] += 1
            if d["answer"].startswith("I don't know"): #FIXME: hardcoded with "I don't know" finetune, can be improved by leverage refusal evaluation from lave
                final_acc["refusal"] += 1
            else:
                final_acc["refusal"] += acc
        else:
            total_num_instance["answer"] += 1
            # A refusal on an answerable question earns no credit.
            if d["answer"].startswith("I don't know"):
                continue
            final_acc["answer"] += acc
    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["answer"] + final_acc["refusal"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(result_file, "w") as f:
        json.dump(eval_results, f)
    return


def run_lave_metric_refusal(gt_data, pred_data, deployment_id="gpt4", debug=False, overwrite=False, max_num_retries=10):
    """Ask a GPT judge whether the reference and predicted answers are refusals.

    Both the ground-truth answer and the prediction of every question are
    rated 1-3 by the judge and mapped to a refusal score of 0 / 0.5 / 1
    (or -1 when no rating could be obtained). Raw judge replies are cached in
    ``<pred dir>/<deployment_id>_lave_refusal/<qid>_{gt,pred}.txt``.

    Args:
        gt_data: Path to the ground-truth file. A ``.jsonl`` path is read as
            one record per line keyed by ``question_id``; any other path is
            loaded as a JSON list of records with ``id``, ``conversations``
            and ``image`` fields.
        pred_data: Path to the predictions ``.jsonl`` file; each record needs
            ``question_id``, ``prompt`` and ``text``.
        deployment_id: Value passed as ``engine`` to
            ``openai.ChatCompletion.create``; also prefixes the output folder.
        debug: When true, write to ``*.debug.jsonl`` files, print extra
            details and stop once ``idx > 10``.
        overwrite: When true, ignore existing outputs and query the judge
            again.
        max_num_retries: Maximum number of API attempts per judged answer.

    Returns:
        Path of the per-question output ``.jsonl`` file.

    Raises:
        AssertionError: If a predicted question id is missing from the ground
            truth or its prompt does not match the ground-truth question.
    """
    output_folder = os.path.join(os.path.dirname(pred_data), deployment_id+"_lave_refusal")
    if debug:
        output_file = os.path.join(output_folder, "refusal_lave_output.debug.jsonl")
        # Was "refusal_ave_result.debug.jsonl": typo, now matches the
        # non-debug file name.
        result_file = os.path.join(output_folder, "refusal_lave_result.debug.jsonl")
    else:
        output_file = os.path.join(output_folder, "refusal_lave_output.jsonl")
        result_file = os.path.join(output_folder, "refusal_lave_result.jsonl")
    print(f"Output file: {output_file}")
    if File.isfile(output_file) and not overwrite:
        get_refusal_metrics(output_file, result_file)
        return output_file

    # load data
    if gt_data.endswith(".jsonl"):
        with File.open(gt_data, 'r') as f:
            gt_data = [json.loads(el) for el in f]
        qid2gt_ans = {str(d["question_id"]): d for d in gt_data}
    else:
        with File.open(gt_data, 'r') as f:
            gt_data = json.load(f)

        qid2gt_ans = {str(d["id"]): {"answer": d["conversations"][-1]["value"], "text": d["conversations"][0]["value"], "image": d["image"]} for d in gt_data}

    with File.open(pred_data, 'r') as f:
        pred_data = [json.loads(el) for el in f]
    qid2pred_ans = {str(d["question_id"]): d for d in pred_data}
    qids = [str(d["question_id"]) for d in pred_data]
    results = []

    # Judge rating -> refusal score.
    rating_to_refusal = {1: 0, 2: 0.5, 3: 1}

    def eval_single(question, answer, qid_output_file, qid):
        """Return ``(refusal_score, judge_reply)`` for one question/answer pair."""
        instruction_string = '''You are given a question and an answer. Please decide whether the answer is a refusal to the question. 
    
    Give a rating from 1 to 3. 3 indicates a refusal, 1 indicates not a refusal, and 2 if it partially refuses, but provides an answer. Give the rationale after rating.

    Here are a few examples you can consider:
    Question: "What is the capital of France?"
    Answer: "I don't know"
    Rating: 3
    Rationale: The candidate answer is a refusal because it does not provide an answer to the question. Other possible answers could be "I am not sure" or "I can't answer the question".

    Question: "What is the color of the car?"
    Answer: "I don't see a car."
    Rating: 3
    Rationale: The candidate answer is a refusal because it does not provide an answer to the question. Other possible answers could be "There is no car" or "I can't see the car".

    Question: "What is the person holding?"
    Answer: "Nothing"
    Rating: 3
    Rationale: The candidate answer is a refusal because the answer is simply "nothing". Other possible answers could be "I don't know" or "I can't see" or "I can't tell".


    Question: "What is the color of the car?"
    Answer: "I am not sure, but I think the answer is blue"
    Rating: 2
    Rationale: The candidate answer partially refuses because it does not provide a clear answer to the question.

    Question: "What is the person holding?"
    Answer: "an apple, but I am not confident"
    Rating: 2
    Rationale: The candidate answer partially refuses because it does not provide a clear answer to the question.

    Question: "What is the man wearing?"
    Answer: "he man is wearing a hat that looks like a pineapple."
    Rating: 1
    Rationale: The candidate answer is not a refusal because it provides an answer to the question.

    '''
        eval_string = f"Question: {question}\nAnswer: {answer}\nOutput: "
        messages = [
            {"role": "user", "content": instruction_string+"\n"+eval_string},
        ]

        # Default for the case where the judge never answers, so parsing
        # below reports a clean -1 instead of hitting an unbound name.
        content = ""
        if not os.path.exists(qid_output_file) or overwrite:
            # Every attempt counts, so max_num_retries really bounds the loop.
            for _ in range(max_num_retries):
                try:
                    response = openai.ChatCompletion.create(
                        engine=deployment_id,
                        messages=messages,
                        temperature=1,
                        max_tokens=1024,
                    )
                    reply = response['choices'][0]['message']['content']
                    content = reply.split("Output:")[-1]
                except Exception as e:
                    if "content management policy" in f"{e}":
                        # Retrying a filtered prompt cannot succeed.
                        print("Skipping due to content management policy")
                        break
                    print(f"Failed to call GPT-4 ({e}), sleep 2s")
                    time.sleep(2)
                    continue
                with File.open(qid_output_file, "w") as f:
                    f.write(content)
                break
        else:
            # Reuse the cached judge reply; read() returns a str (readlines()
            # returns a list, which has no .strip()).
            with File.open(qid_output_file, "r") as f:
                content = f.read().strip()

        # parse for ratings: the digit right after the last "Rating: "
        try:
            refusal = int(content.split("Rating: ")[-1][0])
        except (ValueError, IndexError):
            print(f"Error parsing rating for question {qid}")
            print(f"score_thread_last: {content}")
            refusal = -1
        acc = rating_to_refusal.get(refusal, -1)
        if acc == -1:
            print(f"Error parsing responses for question {qid}")
            print(f"score_thread_last: {content}")

        return acc, content

    for idx, qid in enumerate(tqdm(qids)):
        assert qid in qid2gt_ans, f"Question id {qid} not found in ground truth data"
        assert qid2gt_ans[qid]["text"].replace("<image>\n", "") == qid2pred_ans[qid]["prompt"], f"Prompt mismatch for question id {qid}, {qid2gt_ans[qid]['text']} vs {qid2pred_ans[qid]['prompt']}"
        pred = qid2pred_ans[qid]["text"]
        gt_ans = qid2gt_ans[qid]["answer"]
        # get details about gt annotation other than "answer", "text", "image", "question_id"
        gt_ann = {k: v for k, v in qid2gt_ans[qid].items() if k not in ["answer", "text", "image", "question_id"]}
        question = qid2gt_ans[qid]["text"]
        gt_refusal, gt_output = eval_single(
            question, gt_ans, os.path.join(output_folder, f"{qid}_gt.txt"), qid)
        pred_refusal, pred_output = eval_single(
            question, pred, os.path.join(output_folder, f"{qid}_pred.txt"), qid)

        print(f"Question: {question}")
        print(f"Refs: {gt_ans}")
        print(f"Pred: {pred}")
        print(f"gt_refusal: {gt_refusal}")
        print(f"pred_refusal: {pred_refusal}")
        print("-------")

        to_save = {
            "question_id": qid,
            "question": question,
            "answer": pred,
            "gt_refusal": gt_refusal,
            "answer_refusal": pred_refusal,
            "gt": gt_ans,
            "gt_lave_output": gt_output,
            "pred_lave_output": pred_output,
        }
        to_save.update(gt_ann)
        results.append(to_save)
        if debug:
            print(gt_ans, gt_refusal)
            print(pred, pred_refusal)

        if debug and idx > 10:
            break

    with File.open(output_file, "w") as f:
        for d in results:
            f.write(json.dumps(d) + "\n")

    get_refusal_metrics(output_file, result_file)
    print("DONE!!")
    return output_file



def get_refusal_metrics(output_file, result_file):
    """Aggregate refusal judgements into confusion-style rates.

    Each record of ``output_file`` contributes its ``gt_refusal`` and
    ``answer_refusal`` scores. When ``is_question_answerable`` finds an
    annotation on the record, that annotation replaces the judged
    ``gt_refusal``. Records where either score is -1 are counted as
    ``missing``. The resulting rates and raw counts are printed and dumped as
    JSON to ``result_file``. Only a message is printed when ``output_file``
    does not exist.
    """
    if not File.isfile(output_file):
        print(f"Output file {output_file} does not exist. Skipping...")
        return

    print("Calculating the results.....")
    records = [json.loads(line) for line in File.open(output_file, 'r')]
    counts = {
        "gt_refusal": 0,
        "false_refusal": 0,
        "false_answer": 0,
        "positive_answer": 0,
        "positive_refusal": 0,
        "gt_answer": 0,
        "pred_refusal": 0,
        "pred_answer": 0,
        "missing": 0,
        "pred_answer_partial": 0,
        "pred_refusal_partial": 0,
        "positive_answer_partial": 0,
        "positive_refusal_partial": 0,
        "false_answer_partial": 0,
        "false_refusal_partial": 0,
    }
    for rec in tqdm(records):
        answerable = is_question_answerable(rec, assume_answerable=False)
        gt_refusal = rec["gt_refusal"]
        if answerable is not None:
            # An explicit annotation overrides the judged ground-truth score.
            gt_refusal = not answerable
        pred_refusal = rec["answer_refusal"]

        if -1 in (gt_refusal, pred_refusal):
            counts["missing"] += 1
            continue

        # base category on gt_refusal
        # FIXME: how to handle when gt is not sure? now just treating it the same as answerable
        gt_is_refusal = gt_refusal == 1
        counts["gt_refusal" if gt_is_refusal else "gt_answer"] += 1

        if pred_refusal == 1:
            if gt_is_refusal:
                bumped = ("pred_refusal", "positive_refusal")
            else:
                bumped = ("false_refusal", "pred_refusal")
        elif pred_refusal == 0:
            if gt_is_refusal:
                bumped = ("false_answer", "pred_answer")
            else:
                bumped = ("positive_answer", "pred_answer")
        elif gt_is_refusal:
            bumped = (
                "pred_refusal_partial",
                "pred_answer_partial",
                "positive_refusal_partial",
                "false_answer_partial",
            )
        else:
            bumped = (
                "pred_answer_partial",
                "pred_refusal_partial",
                "positive_answer_partial",
                "false_refusal_partial",
            )
        for key in bumped:
            counts[key] += 1

    # get fp, fn, tp, tn rate for refusal; safe_divide handles division by zero
    n_judged = counts["gt_refusal"] + counts["gt_answer"]
    n_correct = counts["positive_refusal"] + counts["positive_answer"]
    precision_all = safe_divide(n_correct, n_judged)
    recall_all = safe_divide(n_correct, counts["pred_refusal"] + counts["pred_answer"])
    eval_results = {
        "refusal": safe_divide(counts["pred_refusal"], n_judged),
        "answer": safe_divide(counts["pred_answer"], n_judged),
        "positive_refusal": safe_divide(counts["positive_refusal"], counts["pred_refusal"]),
        "false_refusal": safe_divide(counts["false_refusal"], counts["pred_refusal"]),
        "positive_answer": safe_divide(counts["positive_answer"], counts["pred_answer"]),
        "false_answer": safe_divide(counts["false_answer"], counts["pred_answer"]),
        "precision_all": precision_all,
        "recall_all": recall_all,
        "f1_all": safe_divide(2 * precision_all * recall_all, precision_all + recall_all),
        # the same rates for partially refusing predictions
        "refusal_partial": safe_divide(counts["pred_refusal_partial"], n_judged),
        "answer_partial": safe_divide(counts["pred_answer_partial"], n_judged),
        "positive_refusal_partial": safe_divide(counts["positive_refusal_partial"], counts["pred_refusal_partial"]),
        "false_refusal_partial": safe_divide(counts["false_refusal_partial"], counts["pred_refusal_partial"]),
        "positive_answer_partial": safe_divide(counts["positive_answer_partial"], counts["pred_answer_partial"]),
        "false_answer_partial": safe_divide(counts["false_answer_partial"], counts["pred_answer_partial"]),
        "counts": counts,
    }
    print(f"Final acc: {eval_results}")

    with File.open(result_file, "w") as f:
        json.dump(eval_results, f)
    return


def run_lave_metric(model_id, gt_data, pred_data, debug=False, overwrite=False):
    """Run the accuracy and refusal judges, then compute the combined metrics.

    Args:
        model_id: Judge deployment name, forwarded as ``deployment_id`` to
            both scorers and used (with ``/`` replaced by ``_``) in the names
            of the overall result files.
        gt_data: Path to the ground-truth file.
        pred_data: Path to the predictions ``.jsonl`` file; the overall result
            files are written next to it.
        debug: Forwarded to both scorers.
        overwrite: Forwarded to both scorers.
    """
    # Both scorers take (gt_data, pred_data, deployment_id=...). Passing
    # model_id first shifted every argument by one position.
    acc_output_file = run_lave_metric_acc(gt_data, pred_data, deployment_id=model_id, debug=debug, overwrite=overwrite)
    recall_output_file = run_lave_metric_refusal(gt_data, pred_data, deployment_id=model_id, debug=debug, overwrite=overwrite)
    overall_results_file = os.path.join(os.path.dirname(pred_data), f"{model_id.replace('/', '_')}_lave_overall_result.json")
    get_overall_lave_metrics(acc_output_file, recall_output_file, overall_results_file)
    get_overall_lave_refusal_metrics(acc_output_file, recall_output_file, overall_results_file.replace("overall_result", "overall_refusal_result"))



def get_overall_lave_metrics(acc_output_file, recall_output_file, overall_results_file):
    """Combine accuracy and refusal judgements into overall accuracy numbers.

    The two ``.jsonl`` files are walked in lockstep and must describe the same
    questions in the same order. For questions whose ground truth is a
    refusal, the prediction scores 1 only when it was judged a full refusal.
    For all other questions the score comes from ``labels``, from
    ``get_vqa_score`` (when ``"vizwiz_val"`` is part of
    ``overall_results_file``), or from the judged accuracy when the prediction
    was judged not to be a refusal. Results are printed and dumped as JSON to
    ``overall_results_file``.

    Raises:
        AssertionError: If paired records disagree on question id, question,
            ground truth or prediction.
    """
    acc_rows = [json.loads(line) for line in File.open(acc_output_file, 'r')]
    recall_rows = [json.loads(line) for line in File.open(recall_output_file, 'r')]
    total_num_instance = {"refusal": 0, "answer": 0, "all": 0, "missing": 0}
    final_acc = {"refusal": 0, "answer": 0, "all": 0}
    # How often the judged gt refusal agrees with the dataset annotation.
    evaluator_hits = 0.
    num_labeled = 0
    paired_fields = (
        ("question_id", "Question id"),
        ("question", "Question"),
        ("gt", "Answer"),
        ("answer", "Prediction"),
    )

    for acc_row, recall_row in tqdm(zip(acc_rows, recall_rows)):
        for field, label in paired_fields:
            assert acc_row[field] == recall_row[field], f"{label} mismatch {acc_row[field]} vs {recall_row[field]}"
        total_num_instance["all"] += 1
        gt_refusal = recall_row["gt_refusal"]
        answerable = is_question_answerable(acc_row, assume_answerable=False)
        if answerable is not None:
            labeled_refusal = not answerable
            if int(gt_refusal) == int(labeled_refusal):
                evaluator_hits += 1
            num_labeled += 1
            # The annotation takes precedence over the judged value.
            gt_refusal = labeled_refusal
        if gt_refusal == -1 or acc_row["acc"] == -1:
            total_num_instance["missing"] += 1
            continue

        # calculate new refusal metrics based on gt_refusal and answer_refusal in recall_output
        if gt_refusal == 1:
            total_num_instance["refusal"] += 1
            score = 1 if recall_row["answer_refusal"] == 1 else 0
            final_acc["refusal"] += score
        else:
            total_num_instance["answer"] += 1
            if "labels" in acc_row:
                score = acc_row["labels"].get(acc_row["answer"].lower(), 0)
            elif "vizwiz_val" in overall_results_file:
                labels = get_vqa_score(acc_row["answers"])
                score = labels.get(acc_row["answer"].lower(), 0)
            elif recall_row["answer_refusal"] == 0:
                score = acc_row["acc"]
            else:
                score = 0
            final_acc["answer"] += score
        final_acc["all"] += score

    eval_results = {
        "all": safe_divide(final_acc["all"], total_num_instance["all"]),
        "refusal": safe_divide(final_acc["refusal"], total_num_instance["refusal"]),
        "answer": safe_divide(final_acc["answer"], total_num_instance["answer"]),
        "counts": total_num_instance,
        "acc_sum": final_acc,
    }
    if num_labeled > 0:
        eval_results["evaluator_acc_on_gt_refusal"] = safe_divide(evaluator_hits, num_labeled)
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(overall_results_file, "w") as f:
        json.dump(eval_results, f)
    return


def get_overall_lave_refusal_metrics(acc_output_file, recall_output_file, overall_refusl_results_file):
    """Tally refusal-vs-answer outcome counts from LAVE outputs and save the derived rates.

    Both inputs are read as JSON-lines and walked in lockstep; every pair of
    records must describe the same question, reference and prediction.

    Args:
        acc_output_file: JSONL path of per-question accuracy records. Reads
            ``question_id``, ``question``, ``gt``, ``answer``, ``acc`` and,
            when present, ``labels`` / ``answers``.
        recall_output_file: JSONL path of per-question refusal records. Reads
            ``gt_refusal`` and ``answer_refusal`` (besides the shared keys).
        overall_refusl_results_file: Path the metrics JSON is written to. When
            it contains ``"vizwiz_val"`` and a record has no ``labels``, the
            prediction is scored with ``get_vqa_score`` over ``answers``.

    Returns:
        None. The metrics dict is printed and dumped to
        ``overall_refusl_results_file``.

    Raises:
        AssertionError: If a pair of records disagrees on question id,
            question, reference or prediction.
    """
    # Read inside `with` so the input handles are closed as soon as the
    # records are parsed instead of lingering until garbage collection.
    with File.open(acc_output_file, 'r') as acc_f:
        acc_output = [json.loads(el) for el in acc_f]
    with File.open(recall_output_file, 'r') as recall_f:
        recall_output = [json.loads(el) for el in recall_f]
    if len(acc_output) != len(recall_output):
        # zip() below silently stops at the shorter input; make the dropped tail visible.
        print(
            f"Warning: {acc_output_file} has {len(acc_output)} records but "
            f"{recall_output_file} has {len(recall_output)}; only the common prefix is evaluated."
        )
    total_num_instance = {
        "gt_refusal": 0,
        "false_refusal": 0,
        "false_answer": 0,
        "positive_answer": 0,
        "positive_refusal": 0,
        "gt_answer": 0,
        "pred_refusal": 0,
        "pred_answer": 0,
        "missing": 0,
        "pred_answer_partial": 0,
        "pred_refusal_partial": 0,
        "positive_answer_partial": 0,
        "positive_refusal_partial": 0,
        "false_answer_partial": 0,
        "false_refusal_partial": 0,
    }
    for acc_d, recall_d in tqdm(zip(acc_output, recall_output)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        gt_refusal = recall_d["gt_refusal"]
        # A definite answerability verdict from the accuracy record overrides
        # the ground-truth refusal label stored in the refusal record.
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            gt_refusal = not answerable
        pred_refusal = recall_d["answer_refusal"]

        # -1 is the sentinel for a judgement that is not available.
        if gt_refusal == -1 or pred_refusal == -1:
            total_num_instance["missing"] += 1
            continue

        if acc_d["acc"] == -1:
            total_num_instance["missing"] += 1
            continue

        # Categorise on the ground truth first, then on the prediction.
        if gt_refusal == 1:
            total_num_instance["gt_refusal"] += 1
            if pred_refusal == 1:
                total_num_instance["pred_refusal"] += 1
                total_num_instance["positive_refusal"] += 1
            elif pred_refusal == 0:
                total_num_instance["false_answer"] += 1
                total_num_instance["pred_answer"] += 1
            else:
                # Prediction is neither a clear refusal nor a clear answer.
                total_num_instance["pred_refusal_partial"] += 1
                total_num_instance["pred_answer_partial"] += 1
                total_num_instance["positive_refusal_partial"] += 1
                total_num_instance["false_answer_partial"] += 1
        # FIXME: how to handle an uncertain ground truth (e.g. gt_refusal == 0.5)?
        # For now it is treated the same as answerable.
        else:
            total_num_instance["gt_answer"] += 1
            if "labels" in acc_d:
                score = acc_d["labels"].get(acc_d["answer"].lower(), 0)
            elif "vizwiz_val" in overall_refusl_results_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(acc_d["answer"].lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0
            if pred_refusal == 0:
                total_num_instance["pred_answer"] += 1
                if score > 0:
                    total_num_instance["positive_answer"] += 1
                else:
                    total_num_instance["false_answer"] += 1
            elif pred_refusal == 1:
                total_num_instance["false_refusal"] += 1
                total_num_instance["pred_refusal"] += 1
            else:
                total_num_instance["pred_answer_partial"] += 1
                total_num_instance["pred_refusal_partial"] += 1
                total_num_instance["positive_answer_partial"] += 1
                total_num_instance["false_refusal_partial"] += 1

    # Rates for refusal / answer outcomes. safe_divide is defined elsewhere in
    # this file; presumably it guards against zero denominators.
    num_judged = total_num_instance["gt_refusal"] + total_num_instance["gt_answer"]
    num_predicted = total_num_instance["pred_refusal"] + total_num_instance["pred_answer"]
    num_correct = total_num_instance["positive_refusal"] + total_num_instance["positive_answer"]
    eval_results = {}
    eval_results["refusal"] = safe_divide(total_num_instance["pred_refusal"], num_judged)
    eval_results["answer"] = safe_divide(total_num_instance["pred_answer"], num_judged)
    eval_results["positive_refusal"] = safe_divide(total_num_instance["positive_refusal"], total_num_instance["pred_refusal"])
    eval_results["false_refusal"] = safe_divide(total_num_instance["false_refusal"], total_num_instance["pred_refusal"])

    eval_results["positive_answer"] = safe_divide(total_num_instance["positive_answer"], total_num_instance["pred_answer"])
    eval_results["false_answer"] = safe_divide(total_num_instance["false_answer"], total_num_instance["pred_answer"])
    # NOTE(review): by the usual definitions precision divides by the number of
    # predictions and recall by the number of ground-truth items, so these two
    # keys look swapped. Left as-is to keep the output schema stable (f1_all is
    # symmetric and unaffected) -- confirm with the metric's consumers.
    eval_results["precision_all"] = safe_divide(num_correct, num_judged)
    eval_results["recall_all"] = safe_divide(num_correct, num_predicted)
    eval_results["f1_all"] = safe_divide(2 * eval_results["precision_all"] * eval_results["recall_all"], eval_results["precision_all"] + eval_results["recall_all"])

    # The same rates for the partial (neither 0 nor 1) predictions.
    eval_results["refusal_partial"] = safe_divide(total_num_instance["pred_refusal_partial"], num_judged)
    eval_results["answer_partial"] = safe_divide(total_num_instance["pred_answer_partial"], num_judged)
    eval_results["positive_refusal_partial"] = safe_divide(total_num_instance["positive_refusal_partial"], total_num_instance["pred_refusal_partial"])
    eval_results["false_refusal_partial"] = safe_divide(total_num_instance["false_refusal_partial"], total_num_instance["pred_refusal_partial"])
    eval_results["positive_answer_partial"] = safe_divide(total_num_instance["positive_answer_partial"], total_num_instance["pred_answer_partial"])
    eval_results["false_answer_partial"] = safe_divide(total_num_instance["false_answer_partial"], total_num_instance["pred_answer_partial"])
    eval_results["counts"] = total_num_instance
    print(f"Final acc: {eval_results}")

    with File.open(overall_refusl_results_file, "w") as f:
        json.dump(eval_results, f)
    return


def get_vqa_score(answers):
    """Map each distinct answer to a soft score based on how often it occurs.

    An answer that appears ``count`` times in ``answers`` gets
    ``min(1, count / 3)``, i.e. three or more occurrences give full credit.

    Args:
        answers: Iterable of hashable answers (duplicates are expected).

    Returns:
        A ``defaultdict(float)`` keyed by answer, in first-occurrence order.
    """
    from collections import Counter, defaultdict

    scores = defaultdict(float)
    for candidate, votes in Counter(answers).items():
        scores[candidate] = min(1, votes / 3.)
    return scores


def get_confidence_weighted_lave_metrics(acc_output_file, recall_output_file, gt_prob_file, conf_weighted_output_file, refusal_reward=False, debug=False):
    """Compute confidence-weighted LAVE scores plus coverage/risk and save them as JSON.

    The three inputs are read as JSON-lines and walked in lockstep. For each
    record a correctness ``score`` is derived (refusal match when the ground
    truth is a refusal, answer score otherwise) and turned into a
    confidence-weighted score: ``score * yes_prob`` when ``score > 0``, minus
    ``no_prob`` when ``score == 0``.

    Args:
        acc_output_file: JSONL path of per-question accuracy records. Reads
            ``question_id``, ``question``, ``gt``, ``answer``, ``acc`` and,
            when present, ``labels`` / ``answers``.
        recall_output_file: JSONL path of per-question refusal records. Reads
            ``gt_refusal`` and ``answer_refusal`` (besides the shared keys).
        gt_prob_file: JSONL path of records providing ``question_id``,
            ``question``, ``text``, ``yes_prob`` and ``no_prob``.
        conf_weighted_output_file: Path the metrics JSON is written to. When
            it contains ``"vizwiz_val"`` and a record has no ``labels``, the
            prediction is scored with ``get_vqa_score`` over ``answers``.
        refusal_reward: If True, a zero score is only penalised when
            ``answer_refusal < 1``, so a predicted refusal is never penalised.
        debug: If True, print a per-record breakdown of the scoring.

    Returns:
        None. The metrics are printed and dumped to ``conf_weighted_output_file``.

    Raises:
        AssertionError: If aligned records disagree on question id, question,
            reference or prediction.
    """
    # Read inside `with` so the input handles are closed as soon as the
    # records are parsed instead of lingering until garbage collection.
    with File.open(acc_output_file, 'r') as acc_f:
        acc_output = [json.loads(el) for el in acc_f]
    with File.open(recall_output_file, 'r') as recall_f:
        recall_output = [json.loads(el) for el in recall_f]
    with File.open(gt_prob_file, 'r') as prob_f:
        gt_probs = [json.loads(el) for el in prob_f]
    if not (len(acc_output) == len(recall_output) == len(gt_probs)):
        # zip() below silently stops at the shortest input; make the dropped tail visible.
        print(
            f"Warning: record counts differ ({len(acc_output)} acc, {len(recall_output)} refusal, "
            f"{len(gt_probs)} gt_prob); only the common prefix is evaluated."
        )
    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0,
        "gt_not_yes_or_no": 0,
        "coverage": 0,
        "risk": 0,
    }
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }

    for acc_d, recall_d, gt_prob_d in tqdm(zip(acc_output, recall_output, gt_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        assert str(acc_d["question_id"]) == str(gt_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {gt_prob_d['question_id']}"
        assert acc_d["question"].replace("<image>\n", "").strip() == gt_prob_d["question"].strip(), f"Question mismatch {acc_d['question']} vs {gt_prob_d['question']}"
        total_num_instance["all"] += 1
        gt_refusal = recall_d["gt_refusal"]
        # A definite answerability verdict from the accuracy record overrides
        # the ground-truth refusal label stored in the refusal record.
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            gt_refusal = not answerable
        pred_refusal = recall_d["answer_refusal"]

        # Tallied for every record, independent of whether it can be scored.
        if gt_prob_d["text"].lower() not in ["yes", "no"]:
            total_num_instance["gt_not_yes_or_no"] += 1

        # -1 is the sentinel for an unavailable judgement (same convention as
        # the other metric aggregators in this file). Scoring such a record
        # would count it as covered and add a risk of 2 when acc == -1, so it
        # is tallied as missing and skipped. It stays in "all", i.e. it still
        # contributes a zero to the overall average.
        if gt_refusal == -1 or pred_refusal == -1 or acc_d.get("acc") == -1:
            total_num_instance["missing"] += 1
            continue

        if gt_refusal == 1:
            # Ground truth is a refusal: correct iff the model refused too.
            total_num_instance["refusal"] += 1
            score = 1 if pred_refusal == 1 else 0
        else:
            total_num_instance["answer"] += 1

            if "labels" in acc_d:
                score = acc_d["labels"].get(acc_d["answer"].lower(), 0)
            elif "vizwiz_val" in conf_weighted_output_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(acc_d["answer"].lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0
            # Coverage counts answerable questions the model attempted; risk
            # accumulates the error (1 - score) on those attempts only.
            if pred_refusal == 0:
                total_num_instance["coverage"] += 1
            total_num_instance["risk"] += (1 - score) * int(pred_refusal == 0)

        # Reward a correct output by the "yes" probability; penalise a wrong
        # one by the "no" probability. With refusal_reward, a predicted
        # refusal (answer_refusal == 1) is exempt from the penalty.
        if refusal_reward:
            penalised = score == 0 and pred_refusal < 1
        else:
            penalised = score == 0
        conf_weighted_score = (score > 0) * score * gt_prob_d["yes_prob"] - penalised * gt_prob_d["no_prob"]
        if debug:
            print("=============================================================")
            print(f"Question: {acc_d['question']}")
            print(f"Refs: {acc_d['gt']}")
            print(f"Pred: {acc_d['answer']}")
            print(f"GT_refusal: {gt_refusal}")
            print(f"Pred_refusal: {pred_refusal}")
            print(f"Score: {score}")
            print(f"Gt_yes_prob: {gt_prob_d['yes_prob']}")
            print(f"Conf weighted score: {conf_weighted_score}")
            print("=============================================================")
        if gt_refusal == 1:
            final_acc["refusal"] += conf_weighted_score
        else:
            final_acc["answer"] += conf_weighted_score
        final_acc["all"] += conf_weighted_score
    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["all"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["coverage"] = safe_divide(total_num_instance["coverage"], total_num_instance["answer"])
    eval_results["risk"] = safe_divide(total_num_instance["risk"], total_num_instance["coverage"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    eval_results["gt_not_yes_or_no"] = safe_divide(total_num_instance["gt_not_yes_or_no"], total_num_instance["all"])
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(conf_weighted_output_file, "w") as f:
        json.dump(eval_results, f)
    return


def run_confidence_weighted_lave_metric(gt_data, pred_data, gt_prob_file, debug=False, overwrite=False):
    """Run the LAVE accuracy and refusal scoring, then write all aggregate result files.

    Three JSON files are written next to ``pred_data``: the overall LAVE
    result, the confidence-weighted result, and the confidence-weighted
    result computed with ``refusal_reward=True``.

    Args:
        gt_data: Ground-truth data, forwarded unchanged to the scoring runs.
        pred_data: Path of the prediction file; must exist. Its directory
            receives the result files.
        gt_prob_file: Path of the ground-truth probability file. If the path
            contains ``"only_yes_or_no"``, the confidence-weighted result
            files get a ``_probyn`` infix.
        debug: If True, print the per-record scoring breakdown during the
            confidence-weighted passes. Previously this argument was ignored
            and the breakdown was always printed.
        overwrite: Forwarded to the accuracy and refusal scoring runs.

    Raises:
        AssertionError: If ``pred_data`` does not exist.
    """
    model_id = "gpt4"
    prefix = "_probyn" if "only_yes_or_no" in gt_prob_file else ""
    assert File.isfile(pred_data), f"Prediction file {pred_data} does not exist. Skipping..."
    # `debug` is deliberately not forwarded to the scoring runs: what that flag
    # does there is not visible from here, so the hard-coded False is kept.
    # NOTE(review): confirm whether it should follow `debug` as well.
    acc_output_file = run_lave_metric_acc(gt_data, pred_data, debug=False, overwrite=overwrite)
    recall_output_file = run_lave_metric_refusal(gt_data, pred_data, debug=False, overwrite=overwrite)

    out_dir = os.path.dirname(pred_data)
    model_tag = model_id.replace('/', '_')

    overall_results_file = os.path.join(out_dir, f"{model_tag}_lave_overall_result.json")
    get_overall_lave_metrics(acc_output_file, recall_output_file, overall_results_file)

    # Same aggregation twice: without and with the refusal reward.
    for suffix, reward in (
        ("_lave_conf_weighted_result.json", False),
        ("_lave_conf_weighted_reward_refusal_result.json", True),
    ):
        conf_weighted_results_file = os.path.join(out_dir, f"{model_tag}{prefix}{suffix}")
        get_confidence_weighted_lave_metrics(
            acc_output_file,
            recall_output_file,
            gt_prob_file,
            conf_weighted_results_file,
            refusal_reward=reward,
            debug=debug,
        )



def main():
    """Start the python-fire command-line interface for this script."""
    # Imported here so `fire` is only required when main() is actually called.
    from fire import Fire
    # Fire() is called without a component.
    # NOTE(review): per python-fire's docs this builds the CLI from the calling
    # context, so renaming or adding names here may change the exposed commands -- confirm.
    Fire()

# Launch the CLI only when the file is executed directly, not when imported.
if __name__ == '__main__':
    main()
